"""Abstract prop."""
import warnings
from pathlib import Path
from typing import Union, Iterable, Optional

import numpy as np
from dm_control import mjcf
from mojo import Mojo
from mojo.elements import Body, Geom, Site, MujocoElement
from mujoco_utils import mjcf_utils
from pyquaternion import Quaternion

from bigym.utils.physics_utils import has_collided_collections, get_colliders


class Prop:
    """Abstract prop.

    Wraps a body loaded through Mojo and provides helpers to read and set its
    pose, to temporarily disable/enable it, and to run collision checks.
    Subclasses can customize the loaded model by overriding ``_on_loaded``.
    """

    # Position the body is moved to by ``disable()``.
    _HIDDEN_POSITION = np.array([0, 0, -100])
    # Damping written to the free joint while the prop is disabled.
    _DISABLED_DAMPING = 10e6

    def __init__(
        self,
        mojo: Mojo,
        model_path: Path,
        kinematic: bool = False,
        cache_colliders: bool = False,
        cache_sites: bool = False,
        parent: Optional[MujocoElement] = None,
    ):
        """Load the prop model and record its initial collision settings.

        Args:
            mojo: Mojo instance used to load the model and access physics.
            model_path: Path of the model file; passed to Mojo as a string.
            kinematic: If True, ``set_kinematic(True)`` is called on the body.
            cache_colliders: If True, the collidable geoms of the body are
                looked up once and stored in ``colliders``; otherwise
                ``colliders`` is left empty.
            cache_sites: If True, the sites found in the body's MJCF are
                stored in ``sites``. A ``ValueError`` raised during the lookup
                is reported as a warning and ``sites`` stays empty.
            parent: Optional element forwarded to ``Mojo.load_model`` as the
                parent of the loaded model.
        """
        self._mojo = mojo
        self.body: Body = self._mojo.load_model(
            str(model_path), on_loaded=self._process_model, parent=parent
        )
        self.colliders: list[Geom] = (
            self.get_colliders(self.body) if cache_colliders else []
        )
        self.sites: list[Site] = []
        if cache_sites:
            try:
                site_elements = mjcf_utils.safe_find_all(self.body.mjcf, "site")
                self.sites = [
                    Site.get(self._mojo, site.full_identifier) for site in site_elements
                ]
            except ValueError as error:
                # Best effort: a failed site lookup must not prevent prop creation.
                warnings.warn(str(error))
        self.is_kinematic = kinematic
        if self.is_kinematic:
            self.body.set_kinematic(True)

        # Remember each geom's collision masks so ``enable()`` can restore
        # them after ``disable()`` has zeroed them.
        self._geoms = self.body.geoms
        self._geoms_settings: dict[mjcf.Element, tuple[int, int]] = {}
        for geom in self._geoms:
            binding = self._mojo.physics.bind(geom.mjcf)
            self._geoms_settings[geom.mjcf] = (
                binding.contype,
                binding.conaffinity,
            )

    def _process_model(self, model: mjcf.RootElement):
        """Forward the loaded model to ``_on_loaded``.

        Passed to ``Mojo.load_model`` as the ``on_loaded`` callback.
        """
        self._on_loaded(model)

    def _on_loaded(self, model: mjcf.RootElement):
        """Callback to customize prop model.

        The base implementation does nothing; subclasses may override it.
        """
        pass

    def get_pose(self) -> np.ndarray:
        """Get pose in the world space.

        Returns:
            The body position followed by the body quaternion, concatenated
            along the last axis.
        """
        return np.concatenate(
            (self.body.get_position(), self.body.get_quaternion()),
            axis=-1,
        )

    def set_pose(
        self,
        position: np.ndarray = np.zeros(3),
        quat: np.ndarray = Quaternion().elements,
        position_bounds: np.ndarray = np.zeros(3),
        rotation_bounds: np.ndarray = np.zeros(3),
    ):
        """Set pose in the world space, optionally with random noise.

        Args:
            position: Target position.
            quat: Target orientation as quaternion elements accepted by
                ``pyquaternion.Quaternion``.
            position_bounds: Per-axis half-range; a uniform offset sampled
                from ``[-bound, bound]`` is added to ``position``.
            rotation_bounds: Per-axis half-range of the angle (radians, per
                pyquaternion's ``angle`` argument) of extra rotations about
                the x, y and z axes, sampled uniformly from
                ``[-bound, bound]`` and multiplied onto ``quat`` in that order.
        """
        # Convert up front so plain lists/tuples work too: unary minus is
        # not defined for Python sequences.
        position_bounds = np.asarray(position_bounds)
        rotation_bounds = np.asarray(rotation_bounds)

        offset_pos = np.random.uniform(-position_bounds, position_bounds)
        pos = position + offset_pos

        offset_rot = np.random.uniform(-rotation_bounds, rotation_bounds)
        quat = (
            Quaternion(quat)
            * Quaternion(axis=[1, 0, 0], angle=offset_rot[0])
            * Quaternion(axis=[0, 1, 0], angle=offset_rot[1])
            * Quaternion(axis=[0, 0, 1], angle=offset_rot[2])
        )

        self.body.set_position(pos, True)
        self.body.set_quaternion(quat.elements, True)

    def disable(self):
        """Disable prop.

        Zeroes ``contype`` and ``conaffinity`` of every geom, sets the free
        joint damping to ``_DISABLED_DAMPING`` when the body is kinematic, and
        moves the body to ``_HIDDEN_POSITION``.
        """
        for geom in self._geoms:
            binding = self._mojo.physics.bind(geom.mjcf)
            binding.contype = 0
            binding.conaffinity = 0
        if self.body.is_kinematic():
            freejoint = self._mojo.physics.bind(self.body.mjcf.freejoint)
            # NOTE(review): presumably keeps the hidden body from drifting
            # while it is parked - confirm.
            freejoint.damping = self._DISABLED_DAMPING
        self.body.set_position(self._HIDDEN_POSITION, True)

    def enable(self):
        """Enable prop.

        Restores the ``contype``/``conaffinity`` values recorded at
        construction time and resets the free joint damping to 0 when the body
        is kinematic. The body is not moved back; callers are expected to
        place it themselves (e.g. with ``set_pose``).
        """
        for geom in self._geoms:
            contype, conaffinity = self._geoms_settings[geom.mjcf]
            binding = self._mojo.physics.bind(geom.mjcf)
            binding.contype = contype
            binding.conaffinity = conaffinity
        if self.body.is_kinematic():
            freejoint = self._mojo.physics.bind(self.body.mjcf.freejoint)
            freejoint.damping = 0

    def get_velocities(self) -> np.ndarray:
        """Get velocities of the free body.

        Returns:
            A copy of the free joint ``qvel`` when the prop is kinematic,
            otherwise an empty array.
        """
        if self.is_kinematic:
            return np.array(self._mojo.physics.bind(self.body.mjcf.freejoint).qvel)
        else:
            return np.zeros(0)

    def is_static(self, atol_pos: float = 1.0e-3) -> bool:
        """Check if object is not moving.

        Only the first three velocity components are compared against zero.
        A non-kinematic prop has no velocities and is therefore always
        reported as static.

        Args:
            atol_pos: Absolute tolerance used for the comparison.
        """
        velocities = self.get_velocities()
        return np.allclose(velocities[:3], 0, atol=atol_pos)

    def is_colliding(self, other: Union[Geom, Iterable[Geom], "Prop"]) -> bool:
        """Check collision between this prop and another object.

        Args:
            other: A geom, an iterable of geoms, or another prop.

        Returns:
            Result of ``has_collided_collections`` for this prop's colliders
            and the colliders of ``other``.
        """
        # Module-level helper from physics_utils, not the static method below.
        other_colliders = get_colliders(other)
        # ``colliders`` is only populated when the prop was created with
        # ``cache_colliders=True``; look the geoms up on demand otherwise so
        # the check is not silently run against an empty collection.
        own_colliders = self.colliders or self.get_colliders(self.body)
        return has_collided_collections(
            self._mojo.physics, own_colliders, other_colliders
        )

    @staticmethod
    def get_colliders(body: Body) -> list[Geom]:
        """Get all colliders of the body.

        Returns:
            The geoms of ``body`` for which ``is_collidable()`` is true.
        """
        return [g for g in body.geoms if g.is_collidable()]
